import torch

import math
import os
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.cuda.amp import autocast
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from Modules import *
from LinearModule_utils import *
from activations import *



#------------------------------------------------------------------#
#embedding structure: [ signal, memory, position ]
#k blank tokens (config.num_blanks) are inserted to separate different sub-sequences
#------------------------------------------------------------------#



#------------------------------------------------------------------#
#config contains the following important parameters:
#config.signal_start : Start index of the current signal embeddings (always 0)
#config.signal_end : End index of the current signal
#config.memory_start : Start index of the memorized embeddings (from a previous layer)
#config.memory_end : End index of the memorized embeddings (from a previous layer)
#config.position_start : Start index of the one-hot position embeddings
#config.seq_length : Sequence length of the smaller model that we are trying to simulate
#config.blank_identifier : Index containing the identifier for blank tokens
#config.num_blanks : Number of blanks used to separate the sub-sequences
#config.num_attention_heads : Number of attention heads
#config.scale_embeddings : Scale used to initialize the query/key matrices
#config.inner_lr : Inner learning rate used to simulate SGD
#Also read by the modules below: hidden_size, position_dim, gate_scale,
#epsilon, activation_function, use_einsum
#------------------------------------------------------------------# 
class ActivationForward (nn.Module):
    """Simulate an activation layer's forward pass with a channel-wise MLP and gates.

    The embedding is split into equal-width channels. ``self.mlp_module`` is an
    MLP whose per-channel output projection (``c_proj``) is hand-initialized
    here; ``self.mlp_gates`` then mixes the MLP output with the original
    hidden states according to the position embeddings.

    Args:
        config: model config. This class reads ``hidden_size``,
            ``num_attention_heads``, ``num_blanks``, ``seq_length``,
            ``position_dim``, ``gate_scale`` and ``use_einsum``.
        din: width of the input signal, stored in the first ``din``
            coordinates of the embedding.
        projection_matrix: optional matrix of shape ``(dout, din)``. When given,
            its ``head_dim``-wide column blocks are loaded into the per-channel
            ``c_proj`` weights and an extra ``Conv2D`` (``self.projection_``)
            sums the first ``din // head_dim`` channels into channel 0.
            May be a tensor or anything ``torch.as_tensor`` accepts.
        memory_index: start of the memory region, or ``-1`` for no memory. When
            set, the first ``din`` input coordinates are copied into
            ``[memory_index, memory_index + din)`` of the output on non-blank
            positions.
    """
    def __init__ (self, config, din, projection_matrix=None, memory_index=-1):
        super(ActivationForward, self).__init__()

        self.din = din
        self.config = config
        self.projection_matrix = projection_matrix
        self.memory_index = memory_index

        # Width of the output signal: projected width if a projection is given.
        self.dout = projection_matrix.shape[0] if projection_matrix is not None else din

        assert memory_index == -1 or memory_index >= self.dout, \
            "Memory interacts with final signal"
        assert memory_index == -1 or memory_index <= config.hidden_size - self.din, \
            "not enough space to store memory"

        if projection_matrix is not None:
            # One channel per dout-wide block, so each column block of the
            # projection fits exactly into one channel's c_proj weight.
            head_dim = self.dout
            num_channels = config.hidden_size // head_dim
        else:
            num_channels = config.num_attention_heads
            head_dim = config.hidden_size // num_channels

        self.mlp_module = MLP(config.hidden_size,
                              config,
                              conv2d=True,
                              transpose_intermediate=True,
                              transpose_proj=False,
                              conv_proj_features=num_channels,
                              )
        self.mlp_gates = Gates(config)
        self.projection_ = None

        proj_dtype = self.mlp_module.c_proj.weight.dtype
        c_proj_init = torch.zeros((num_channels, head_dim, head_dim), dtype=proj_dtype)

        if projection_matrix is not None:
            assert projection_matrix.shape[1] == din, \
                "Projection matrix must have 'din' in second coordinate"
            assert projection_matrix.shape[1] >= head_dim, \
                "Currently, this projection only works when we project down to a lower dimension"
            assert projection_matrix.shape[1] % head_dim == 0, \
                "Perfect division into channels"

            num_useful_channels = projection_matrix.shape[1] // head_dim
            for i in range(num_useful_channels):
                # as_tensor accepts tensors and arrays alike; detach/to avoids the
                # torch.tensor(tensor) copy-construct path (which warns).
                block = torch.as_tensor(projection_matrix[:, i * head_dim:(i + 1) * head_dim])
                c_proj_init[i] = block.detach().to(dtype=proj_dtype)
            self.mlp_module.initialize_weights(c_proj_init=c_proj_init)

            # Sum the useful channels into channel 0, so the projected signal
            # occupies the first dout (== head_dim) coordinates.
            self.projection_ = Conv2D(nf=num_channels, nx=head_dim, transpose=True, use_einsum=self.config.use_einsum)
            with torch.no_grad():
                self.projection_.weight.copy_(torch.zeros(head_dim, num_channels, num_channels))
                self.projection_.weight[:, :num_useful_channels, 0] = 1.
        elif memory_index != -1:
            assert memory_index % head_dim == 0, \
                "Memory should start at a multiple of head_dim!"
            # Identity projection only for channels before the memory region;
            # memory channels stay zero so forward() can write into them.
            mem_head_start = memory_index // head_dim
            c_proj_init[:mem_head_start] = torch.eye(head_dim)
            self.mlp_module.initialize_weights(c_proj_init=c_proj_init)
        else:
            c_proj_init[:] = torch.eye(head_dim)
            self.mlp_module.initialize_weights(c_proj_init=c_proj_init)

        # Gate weights (w, u, v, w_bias, u_bias, v_bias); only v / v_bias are
        # non-zero, so the gates depend on the position embeddings alone.
        # Positions >= seq_length are the blanks; their changes are ignored.
        w = torch.zeros((1, 2 * config.hidden_size))
        u = torch.zeros((1, 2 * config.hidden_size))
        v = torch.zeros((1, 2 * config.position_dim))
        w_bias = torch.zeros(2)
        u_bias = torch.zeros(2)
        v_bias = torch.zeros(2)

        num_blank_positions = config.position_dim - config.seq_length
        # Input gate is 1 on blanks and 0 for non-blanks.
        v[0, config.seq_length: config.position_dim] = config.gate_scale * torch.ones(num_blank_positions)
        # Change gate is 0 on blanks and 1 for non-blanks.
        v[0, config.position_dim + config.seq_length: 2 * config.position_dim] = -config.gate_scale * torch.ones(num_blank_positions)
        v_bias[1] += config.gate_scale

        self.mlp_gates.initialize_weights(w, u, v, w_bias, u_bias, v_bias)

    def forward(self, hidden_states, position_embeddings):
        """Apply the MLP (and optional projection), fill memory, then gate.

        Args:
            hidden_states: input embeddings, indexed as ``[batch, position, feature]``
                (positions before ``config.num_blanks`` are left out of the
                memory write).
            position_embeddings: passed straight to ``self.mlp_gates``.

        Returns:
            The output of ``self.mlp_gates``.
        """
        mlp_out = self.mlp_module.forward(hidden_states)
        if self.projection_ is not None:
            mlp_out = self.projection_(mlp_out)

        if self.memory_index != -1:
            num_blanks = self.config.num_blanks
            # Sanity check: the memory region must still be empty before the
            # input is copied there (.item() forces a device sync).
            assert torch.sum(torch.absolute(mlp_out[:, num_blanks:, self.memory_index:])).item() < 1e-10, \
                "Memory portion not empty!"

            mlp_out[:, num_blanks:, self.memory_index: self.memory_index + self.din] += hidden_states[:, num_blanks:, :self.din]

        return self.mlp_gates.forward(hidden_states, mlp_out, position_embeddings)
    

    
class ActivationBackward (nn.Module):
    """Approximate backpropagation through an elementwise activation.

    Two channel-wise ``Conv2D`` layers with an activation in between are
    hand-initialized so that, assuming the first ``din`` coordinates carry the
    incoming gradient ``g`` and ``[memory_index, memory_index + din)`` carries
    the activation input ``x``:

        c_fc:    signal <- g / s + x,   memory <- x
        proj_fc: signal <- s * act(g / s + x) - s * act(x)

    where ``s = config.scale_embeddings``. This is a finite-difference
    approximation of ``act'(x) * g``. If ``retain_og_act`` is set, the memory
    channels of the output additionally hold ``act(x)``.

    Args:
        config: model config. Reads ``epsilon``, ``hidden_size``,
            ``num_attention_heads``, ``activation_function``,
            ``scale_embeddings`` and ``use_einsum``.
        din: signal width; must be a multiple of ``head_dim``.
        input_projection: accepted for interface compatibility; unused here.
        projection_matrix: accepted for interface compatibility; unused here.
        memory_index: start of the stored activation input. Required (must not
            be ``-1``), a multiple of ``head_dim``, ``>= din``, and with room for
            ``din`` coordinates inside ``hidden_size``.
        retain_og_act: also output ``act(x)`` in the memory channels.
    """
    def __init__ (self, config, din, input_projection=None, projection_matrix=None, memory_index=-1, retain_og_act=False):
        super(ActivationBackward, self).__init__()

        # The stored input x is mandatory: with memory_index == -1 the channel
        # slices below would start at -1 and silently wire nothing.
        assert memory_index != -1, \
            "ActivationBackward requires a memory_index holding the activation input"
        assert memory_index >= din, \
            "memory crosses current signal"
        # Without this, an out-of-range memory_index yields empty slices and
        # the memory wiring is silently dropped.
        assert memory_index <= config.hidden_size - din, \
            "not enough space to store memory"

        self.epsilon = config.epsilon
        self.memory_index = memory_index
        self.config = config
        self.din = din
        self.act = ACT2FN[config.activation_function]

        num_heads = config.num_attention_heads
        head_dim = config.hidden_size // num_heads
        self.c_fc = Conv2D(num_heads, head_dim, transpose=True, use_einsum=self.config.use_einsum)
        self.proj_fc = Conv2D(num_heads, head_dim, transpose=True, use_einsum=self.config.use_einsum)

        assert din % head_dim == 0, \
            " 'din' should be a multiple of head_dim! "
        assert self.memory_index % head_dim == 0, \
            "Memory should start at a multiple of head_dim!"

        num_partitions = din // head_dim
        mem_head_start = self.memory_index // head_dim

        # Channel ranges of the signal (gradient) and of the stored input x.
        sig = slice(0, num_partitions)
        mem = slice(mem_head_start, mem_head_start + num_partitions)
        eye = torch.eye(num_partitions)
        scale = config.scale_embeddings

        # Weights are built as [head_dim, row, col] and transposed on copy; the
        # rows here are presumably the output channels of Conv2D (the forward
        # class uses weight[:, in, out]) -- confirm against Conv2D.
        c_fc_init = torch.zeros((head_dim, num_heads, num_heads))
        c_proj_init = torch.zeros((head_dim, num_heads, num_heads))

        # c_fc: signal <- g / scale + x ; memory <- x (pass x through).
        c_fc_init[:, sig, sig] = 1. / scale * eye
        c_fc_init[:, sig, mem] = eye
        c_fc_init[:, mem, mem] = eye

        # proj_fc: signal <- scale * (act(g / scale + x) - act(x)).
        c_proj_init[:, sig, sig] = scale * eye
        c_proj_init[:, sig, mem] = -scale * eye

        # Optionally keep act(x) in the memory channels for later use.
        if retain_og_act:
            c_proj_init[:, mem, mem] = eye

        with torch.no_grad():
            self.c_fc.weight.copy_(torch.swapaxes(c_fc_init, axis0=-1, axis1=-2))
            self.proj_fc.weight.copy_(torch.swapaxes(c_proj_init, axis0=-1, axis1=-2))

    def forward(self, hidden_states, position_embeddings, attention_mask=None, activation_memory=None, icl_mask=None):
        """Return ``proj_fc(act(c_fc(hidden_states)))``.

        ``position_embeddings``, ``attention_mask``, ``activation_memory`` and
        ``icl_mask`` are accepted for interface compatibility and are unused.
        """
        return self.proj_fc(self.act(self.c_fc(hidden_states)))
    
